#!/usr/bin/env python3

from subprocess import Popen, PIPE, TimeoutExpired, check_output
from random import shuffle
from pprint import pformat
import signal
import os
from multiprocessing import Pool, Value, cpu_count
import copy
import time
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
import logging
import math

# Script-wide logger. Records go to a StreamHandler and are formatted as
# "<timestamp> <logger name> <level> <message>". The level starts at INFO;
# main() lowers it to DEBUG when --debug is given.
log = logging.getLogger("repair-tsd")
log.setLevel(logging.INFO)
ch = logging.StreamHandler()
logformat = '%(asctime)s %(name)s %(levelname)s %(message)s'
formatter = logging.Formatter(logformat)
ch.setFormatter(formatter)
log.addHandler(ch)
# multiprocessing.Value integer counters ('i' typecode), both starting at 0.
# process_metric() increments metric_count for every metric it finishes and
# failed_count for every metric it gives up on, each under the value's lock.
metric_count = Value('i', 0)
failed_count = Value('i', 0)


def get_large_divisors(num):
    """
    Get the "large" divisors of num

    A divisor is "large" when it is strictly greater than the square root
    of num, i.e. multiplying it by itself gives a result > num.

    e.g. 60 -> 60, 30, 20, 15, 12, 10

    * 1, 2, 3, 4, 5 and 6 are also divisors of 60, but they are too small
      when multiplied by themselves
    * An exact square root is excluded as well (36 -> 36, 18, 12, 9)

    :param int num: The number to check (must not be negative)
    :returns: all large divisors, largest first
    :rtype: list
    :raises ValueError: if num is negative (from math.sqrt)
    """
    # Walk the small divisors (1 .. floor(sqrt(num))) and emit each one's
    # partner. Floor division keeps the arithmetic in integers; num / i
    # goes through a float and can lose precision for very large values.
    return [int(num // small)
            for small in range(1, int(math.sqrt(num)) + 1)
            if num % small == 0 and small * small != num]


def get_metrics(args):
    """
    Collect all metrics from OpenTSDB

    Runs ``<tsd_path> uid --config=<cfg_path> grep metrics ".*"`` (through
    ``sudo -u <sudo_user>`` when ``use_sudo`` is set) and builds one work
    item per metric name found in its output.

    :param dict args: Options. Keys read here (with defaults):
                      ``time_chunk`` (60, minutes per chunk),
                      ``time_range`` (48, hours to repair),
                      ``tsd_path``, ``cfg_path``, ``use_sudo`` (False),
                      ``sudo_user`` ("opentsdb"), ``retries`` (1),
                      ``compact`` (False) and ``shuffle`` (False)
    :returns: one dict per metric with the keys ``metric``, ``time_chunk``,
              ``timeout``, ``retries``, ``compact``, ``multiplier``,
              ``metriccount``, ``chunk_count``, ``base`` and ``check_cmd``
    :rtype: list
    """
    time_chunk = args.get("time_chunk", 60)
    time_range = args.get("time_range", 48)
    tsd_path = args.get("tsd_path", "/usr/share/opentsdb/bin/tsdb")
    cfg_path = args.get("cfg_path", "/etc/opentsdb/opentsdb.conf")
    use_sudo = args.get("use_sudo", False)
    sudo_user = args.get("sudo_user", "opentsdb")

    multiplier = int(60 / time_chunk)
    # Ceiling division of the total minutes by the chunk size. Deriving the
    # count from `multiplier` yields 0 chunks whenever time_chunk > 60,
    # which would silently skip every repair.
    chunk_count = -(-(time_range * 60) // time_chunk)

    base = "{} fsck --config={}".format(tsd_path, cfg_path)
    check_cmd = "{} uid --config={} grep metrics".format(tsd_path, cfg_path)
    if use_sudo:
        base = "sudo -u {} {}".format(sudo_user, base)
        check_cmd = "sudo -u {} {}".format(sudo_user, check_cmd)
    # Build the listing command from check_cmd so that it honours use_sudo
    # exactly like the per-metric existence check does.
    cmd = '{} ".*"'.format(check_cmd)
    proc = Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
    stdout, stderr = proc.communicate()
    if proc.returncode != 0:
        log.warning("Metric listing exited with status {} :: {}".format(
            proc.returncode, stderr.decode(errors="replace").strip()))

    # The metric name is the second space-separated field of each line,
    # with any surrounding ':' removed. Lines with fewer than two fields
    # cannot carry a name and are skipped instead of raising IndexError.
    metriclist = []
    for line in stdout.decode().split("\n"):
        fields = line.split(" ")
        if len(fields) < 2:
            continue
        name = fields[1].strip(":")
        if name and name != "\x00":
            metriclist.append(name)

    metricobj = {"time_chunk": time_chunk,
                 # Half the chunk length, expressed in seconds
                 "timeout": int((time_chunk * 60) / 2),
                 "retries": args.get("retries", 1),
                 "compact": args.get("compact", False),
                 "multiplier": multiplier,
                 "metriccount": len(metriclist),
                 "chunk_count": chunk_count,
                 "base": base, "check_cmd": check_cmd}
    # metricobj only holds scalars, so a shallow copy per metric is enough
    metrics = [dict(metricobj, metric=name) for name in metriclist]
    if args.get("shuffle", False):
        shuffle(metrics)
    log.info("There are {} metrics to process".format(len(metrics)))
    return metrics


def _stop_process_group(proc, grace=5):
    """
    Signal a subprocess' process group and reap the subprocess

    Sends SIGTERM, then SIGKILL if the subprocess has not exited within
    ``grace`` seconds. After each signal the pipes are drained and the
    child is waited for, so no zombie or open pipe is left behind.

    :param Popen proc: subprocess to stop
    :param int grace: seconds to wait for exit after each signal
    """
    for sig in (signal.SIGTERM, signal.SIGKILL):
        try:
            os.killpg(os.getpgid(proc.pid), sig)
        except Exception as e:
            log.warning("Couldn't kill subprocess :: {}".format(e))
        try:
            proc.communicate(timeout=grace)
            return
        except TimeoutExpired:
            continue
        except Exception as e:
            log.warning("Couldn't reap subprocess :: {}".format(e))
            return
    log.warning("Subprocess {} did not exit after SIGKILL".format(proc.pid))


def _process_metric_chunk(metric, chunk, x, cmd, timeout):
    """
    Actually run the cli command to repair this chunk.
    we segment this so the calls to both the compact and
    regular runs can be somewhat consistently managed

    The command runs through the shell in its own session (os.setsid) so
    that the whole process group can be signalled if it has to be stopped.
    The exit status of the command is not inspected: any run that finishes
    within the timeout counts as successful.

    :param str metric: name of the metric we're processing
    :param int chunk: Which time segment we're processing
    :param int x: which attempt for the time segment we're on
    :param str cmd: The actual command we'll run
    :param int timeout: How long the command is allowed to run (seconds)
    :returns: False if the command timed out or communicating with it
              raised, True otherwise
    :rtype: bool
    """
    log.debug("Running command: {}".format(cmd))
    metricproc = Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE,
                       preexec_fn=os.setsid)
    try:
        results, _ = metricproc.communicate(timeout=timeout)
    except TimeoutExpired:
        msg = "{}: chunk {} failed (timeout: {}) (run {})".format(metric,
                                                                  chunk,
                                                                  timeout,
                                                                  x)
        log.warning(msg)
        # communicate() does not kill the child on timeout; do it here and
        # reap it, otherwise every timeout leaks a process and two pipes.
        _stop_process_group(metricproc)
        return False
    except Exception as e:
        log.error("{} general exception :: {}".format(metric, e))
        _stop_process_group(metricproc)
        return False
    # Only the last 26 non-empty output lines are looked at
    results = [r for r in results.decode().split("\n") if r][-26:]
    final_results = []
    # We'll only collect results that are non-0
    # since we're not super interested in stuff that didn't change.
    for r in results:
        # Strip the timestamp from the log line
        line = r.split(" ")[6:]
        try:
            if int(line[-1]) != 0:
                final_results.append(" ".join(line))
        except (IndexError, ValueError):
            # No trailing integer on this line: keep it as-is
            final_results.append(" ".join(line))
    result_str = "\n".join(final_results)
    log.debug("{} results:\n{}".format(metric, result_str))
    return True


def repair_metric_chunk(metricobj, chunk):
    """
    Repair one 'chunk' of data for a metric

    Chunk N (1-based) covers the window from ``N * time_chunk`` minutes ago
    to ``(N - 1) * time_chunk`` minutes ago, so consecutive chunks tile the
    requested range without gaps. Chunk 1 ends at "now" and therefore only
    passes a start time.

    The fsck command is attempted ``retries + 1`` times, with the timeout
    scaled by the attempt number. When ``compact`` is set, a second
    ``--fix --compact`` run must also succeed within the same attempt.

    :param dict metricobj: work item carrying the keys ``metric``,
                           ``time_chunk``, ``base``, ``timeout``,
                           ``chunk_count``, ``compact`` and ``retries``
    :param int chunk: 1-based index of the time segment to repair
    :returns: None on success, the metric name if every attempt failed
    :rtype: str
    """
    metric = metricobj["metric"]
    time_chunk = metricobj["time_chunk"]
    base = metricobj["base"]
    timeout = metricobj["timeout"]
    chunk_count = metricobj["chunk_count"]
    compact = metricobj["compact"]
    log.debug("Running chunk {} for {}".format(chunk, metric))
    if chunk < 2:
        timestr = "{}m-ago".format(time_chunk)
    else:
        # start = chunk * time_chunk, end = (chunk - 1) * time_chunk.
        # Using (chunk + 1) / chunk here would skip the window between
        # time_chunk and 2 * time_chunk minutes ago entirely.
        timestr = "{}m-ago {}m-ago".format(chunk * time_chunk,
                                           (chunk - 1) * time_chunk)
    cmd = "{} {} sum".format(base, timestr)
    fullcmd = "{} {} --delete-bad-compacts --delete-bad-rows \
               --delete-bad-values --delete-unknown-columns \
               --delete-orphans --resolve-duplicates --fix".format(cmd,
                                                                   metric)
    ccmd = "{} {} --fix --compact".format(cmd, metric)
    # Even though we're chunking, it's worth trying things more than once
    for x in range(1, metricobj["retries"] + 2):
        log.debug("Repair try {} for {}".format(x, timestr))
        if not _process_metric_chunk(metric, chunk, x, fullcmd, timeout * x):
            continue
        if compact:
            if not _process_metric_chunk(metric, chunk, x, ccmd, timeout * x):
                continue
        # Progress is only reported on every 20th chunk to limit log noise
        if chunk % 20 == 0:
            log.info("{} -> Chunk {} of {} finished".format(metric,
                                                            chunk,
                                                            chunk_count))
        return None
    else:
        # Reached only when no attempt returned, i.e. all of them failed
        log.error("Failed to completely repair {}".format(metric))
        return metric


def process_metric(metricobj):
    """
    Repair a single metric, one time chunk after another

    The metric is first looked up with the work item's ``check_cmd``; if
    that command fails the metric is skipped. Chunks 1..chunk_count are
    then repaired in order, stopping at the first chunk that cannot be
    repaired.

    :param dict metricobj: work item carrying at least the keys ``metric``,
                           ``chunk_count``, ``check_cmd`` and
                           ``metriccount``
    :returns: the metric name if a chunk failed, None otherwise (including
              when the metric was skipped)
    :rtype: str
    """
    name = metricobj["metric"]
    total_chunks = metricobj["chunk_count"]
    try:
        probe = '{} "^{}$"'.format(metricobj["check_cmd"], name)
        check_output(probe, shell=True)
    except Exception:
        log.warning("{} doesn't exist! Skipping...".format(name))
        return None

    log.info("Repairing {} in {} chunks".format(name, total_chunks))
    began = time.time()
    chunk = 1
    while chunk <= total_chunks:
        if repair_metric_chunk(metricobj, chunk):
            # Give up on this metric at the first unrepairable chunk
            with failed_count.get_lock():
                failed_count.value += 1
            return name
        chunk += 1
    elapsed = time.time() - began

    with metric_count.get_lock():
        metric_count.value += 1
    pieces = [
        "COMPLETE: {} repair took {} seconds".format(name, int(elapsed)),
        " ({} of {} metrics complete)".format(metric_count.value,
                                              metricobj["metriccount"]),
        " ({} failed)".format(failed_count.value),
    ]
    log.info("".join(pieces))


def process_metrics(metric_list, threads):
    """
    Repair every metric in metric_list using a pool of worker processes

    :param list metric_list: work items as produced by get_metrics
    :param int threads: number of worker processes; values below 1 are
                        raised to 1 because Pool rejects them
    :returns: names of the metrics that could not be repaired
    :rtype: list
    """
    # The context manager terminates the workers once map() has returned
    # (or raised), so the pool never outlives this call.
    with Pool(max(1, threads)) as pool:
        outcomes = pool.map(process_metric, metric_list)
    # process_metric returns None for repaired/skipped metrics
    return [m for m in outcomes if m]


def cli_opts():
    """
    Parse the command line options

    Numeric options (--time-range, --time-chunk, --retries, --threads) are
    returned as strings; converting and validating them is left to the
    caller.

    :returns: the parsed options
    :rtype: argparse.Namespace
    """
    parser = ArgumentParser(description="Repair all OpenTSDB metrics",
                            formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("--debug", action="store_true", default=False,
                        help="Show debug information")
    parser.add_argument("--time-range", default="48",
                        help="How many hours of time we collect to repair")
    parser.add_argument("--time-chunk", default="60",
                        help="How many minutes of data to scan per chunk")
    parser.add_argument("--retries", default="1",
                        help="How many times we should try failed metrics")
    parser.add_argument("--tsd-path", default="/usr/share/opentsdb/bin/tsdb",
                        help="Path to the OpenTSDB CLI binary")
    parser.add_argument("--cfg-path", default="/etc/opentsdb/opentsdb.conf",
                        help="Path to OpenTSDB config")
    parser.add_argument("--store-path", default="/tmp/opentsdb-fsck.list",
                        help="Path to the fsck list file")
    parser.add_argument("--use-sudo", action="store_true",
                        default=False,
                        help="switch user when running repairs?")
    parser.add_argument("--compact", action="store_true",
                        default=False,
                        help="Run compaction with repairs")
    parser.add_argument("--shuffle", action="store_true",
                        default=False,
                        help="Mix up incoming metric order")
    # Half the CPUs, but never fewer than one worker: on a single-core
    # host cpu_count() // 2 is 0, which is not a usable pool size.
    parser.add_argument("--threads",
                        default="{}".format(max(1, cpu_count() // 2)),
                        help="Total number of metrics to process at once")
    parser.add_argument("--sudo-user", default="opentsdb",
                        help="User to switch to...")
    return parser.parse_args()


def _parse_int_option(raw, label):
    """
    Convert a CLI option to int, logging an error when that fails

    :param str raw: the value as given on the command line
    :param str label: wording used for the option in the error message
    :returns: the parsed value, or None when raw is not an integer
    :rtype: int
    """
    try:
        return int(raw)
    except (TypeError, ValueError) as e:
        log.error("Invalid {} {} :: {}".format(label, raw, e))
        return None


def _valid_time_chunk(time_chunk, chunks):
    """
    Check a chunk size (in minutes) against the allowed values

    :param int time_chunk: the requested chunk size
    :param list chunks: the large divisors of 60
    :returns: True for 60, for any value below 60 that is in chunks, and
              for any value above 60 that is a multiple of a value in
              chunks
    :rtype: bool
    """
    if time_chunk < 60:
        return time_chunk in chunks
    if time_chunk > 60:
        return any(time_chunk % n == 0 for n in chunks)
    return True


def main():
    """
    Parse the CLI options, validate them and repair every metric

    Every invalid numeric option is logged; if at least one was found the
    script exits with status 1 before touching OpenTSDB. Otherwise all
    metrics are collected and repaired, and the metrics that could not be
    repaired are listed at the end.

    :raises SystemExit: when an option fails validation
    """
    args = cli_opts()
    if args.debug:
        log.setLevel(logging.DEBUG)

    time_range = _parse_int_option(args.time_range, "time range")
    retries = _parse_int_option(args.retries, "retry number")
    threads = _parse_int_option(args.threads, "thread count")
    time_chunk = _parse_int_option(args.time_chunk, "time chunk")
    if time_chunk is not None:
        chunks = get_large_divisors(60)
        if not _valid_time_chunk(time_chunk, chunks):
            reason = "must be one of {} or a multiple of one of them".format(
                sorted(chunks))
            log.error("Invalid time chunk {} :: {}".format(args.time_chunk,
                                                           reason))
            time_chunk = None
    # Stop here instead of carrying on with missing or rejected values
    if None in (time_range, retries, threads, time_chunk):
        raise SystemExit(1)

    metric_list = get_metrics({"time_range": time_range,
                               "use_sudo": args.use_sudo,
                               "sudo_user": args.sudo_user,
                               "time_chunk": time_chunk,
                               "tsd_path": args.tsd_path,
                               "cfg_path": args.cfg_path,
                               "store_path": args.store_path,
                               "shuffle": args.shuffle,
                               "compact": args.compact,
                               "retries": retries})
    stime = time.time()
    failed = process_metrics(metric_list, threads)
    etime = time.time()
    log.info("Processed {} metrics in [{}] seconds".format(len(metric_list),
                                                           etime - stime))
    if failed:
        log.warning("{} failed metrics:\n{}".format(len(failed),
                                                    pformat(sorted(failed))))


if __name__ == "__main__":
    main()
